# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import copy
import numpy as np 
from PIL import Image
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset

from semilearn.datasets.augmentation import RandAugment
from semilearn.datasets.utils import get_onehot
from semilearn.datasets.cv_datasets.common.utils import  join

class BasicDataset(Dataset):
    """Unlabeled image dataset that yields weak/strong augmented views per sample.

    Each item is a dict with:
        'idx_ulb': the sample index,
        'x_ulb_w': the weakly transformed view (``transform``),
        'x_ulb_s': the strongly transformed view (``strong_transform``);
                   this key is omitted when ``strong_transform`` is None,
        'y_ulb':   the target, or -1 when no ``targets`` are attached.

    Args:
        alg: algorithm identifier; stored on the instance, not used by this class.
        data: indexable, sized collection of items accepted by ``PIL.Image.open``
            (e.g. file paths). Subclasses may override ``__sample__`` to return
            other image types such as ``np.ndarray``.
        transform: callable producing the weak view. When None, the image is
            only converted with ``transforms.ToTensor()``.
        strong_transform: callable producing the strong view, or None.
    """

    # Class-level defaults so ``__sample__`` works when ``targets`` is attached
    # without these being set; instance attributes assigned anywhere override them.
    onehot = False
    num_classes = None

    def __init__(self, alg, data, transform=None, strong_transform=None):
        super().__init__()
        self.alg = alg
        self.data = data
        # No labels by default; may be replaced by a sequence indexed like ``data``.
        self.targets = None

        self.transform = transform
        self.strong_transform = strong_transform

    def _load_image(self, path):
        """Open ``path`` and return it as an RGB PIL image, closing the file."""
        # ``convert`` forces the pixel data to be loaded, so the returned image
        # remains valid after the context manager closes the underlying file.
        with Image.open(path) as im:
            return im.convert('RGB')

    def __sample__(self, idx):
        """Dataset-specific sample function: return ``(img, target)`` for ``idx``."""
        if self.targets is None:
            target = -1
        else:
            target_ = self.targets[idx]
            target = target_ if not self.onehot else get_onehot(self.num_classes, target_)

        img = self._load_image(self.data[idx])
        return img, target

    def __getitem__(self, idx):
        """Return the dict of index, weak view, optional strong view and target."""
        img, target = self.__sample__(idx)

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        weak_transform = self.transform if self.transform is not None else transforms.ToTensor()
        # Weak view is computed before the strong one to keep the original call
        # order (matters for reproducibility when transforms draw random numbers).
        item = {'idx_ulb': idx, 'x_ulb_w': weak_transform(img)}
        if self.strong_transform is not None:
            item['x_ulb_s'] = self.strong_transform(img)
        item['y_ulb'] = target
        return item

    def __len__(self):
        return len(self.data)